# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import sys


def unroll(uf, IndexType, InType, OutType, use_weights, isa):
    """Emit the C++ lookup loop specialised for a block of ``8 * uf`` floats.

    The generated code keeps ``uf`` ``__m256`` accumulators (``vop0``,
    ``vop8``, ...), adds every looked-up input row into them with FMA
    intrinsics and stores them to ``out`` after the inner loop, optionally
    multiplied by ``1 / lengths[rangeIndex]``.

    Args:
        uf: Unroll factor, i.e. the number of accumulators to emit.
        IndexType: C++ type used for ``indices`` and the loop counters.
        InType: C++ element type of ``input``; one of ``"float"``,
            ``"float16"`` or ``"uint8_t"``.
        OutType: C++ element type of ``out`` and of the scalar weight.
        use_weights: Accepted for interface compatibility; not consulted.
        isa: Accepted for interface compatibility; not consulted.

    Returns:
        A list of C++ source lines (no trailing newlines).

    Raises:
        ValueError: If ``InType`` is not one of the supported types.
    """
    # Size in bytes of one input element.  Unsupported types are rejected
    # with a real exception rather than ``assert False``: assertions are
    # stripped under ``python -O``, which would silently emit a kernel that
    # contains no arithmetic at all.
    element_sizes = {"float": 4, "float16": 2, "uint8_t": 1}
    if InType not in element_sizes:
        raise ValueError("unsupported InType: %s" % (InType,))
    element_size = element_sizes[InType]
    cachelinesize = 64

    def compute(regid, prefetch):
        # The literals below continue across source lines with a trailing
        # backslash, so the indentation of the continuation lines ends up
        # inside the emitted string.  It is kept unchanged on purpose so the
        # generated file stays byte-stable.
        code = []

        if InType == "float":
            code.append("vop%d = _mm256_fmadd_ps(vwgt,  \
                  _mm256_loadu_ps(ip + (%d)), vop%d);" % (regid, regid, regid))
        elif InType == "float16":
            code.append("vop%d = _mm256_fmadd_ps(vwgt,  \
                   _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))), \
                   vop%d);"
                        % (regid, regid, regid))
        else:
            # Only "uint8_t" can reach this branch (validated above).
            code.append("vop%d = _mm256_fmadd_ps(vwgt,  \
                   _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (%d))))), \
                   _mm256_add_ps(vop%d, vbio));"
                        % (regid, regid, regid))

        if prefetch:
            code.append("_mm_prefetch((&ip_next_T0[%d]), _MM_HINT_T0);" % regid)
        else:
            code.append("// skip unecassery prefetch of (&ip_next_T0[%d])" % regid)

        return code

    # Element offset of the first value handled by each accumulator.
    offsets = list(range(0, 8 * uf, 8))
    is_quantized = InType == "uint8_t"

    code = []
    code.append("// unrolling " + str(uf) + " times")
    code.append(IndexType + " dataInd = 0;")
    code.append("for (" + IndexType +
                " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {")
    code.append(OutType + " *op = &out[rangeIndex * block_size];")
    for j in offsets:
        code.append("__m256 vop%d = _mm256_setzero_ps();" % j)

    # inner loop over the indices that belong to this output row
    code.append("for (" + IndexType +
                " start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) {")
    code.append("const  " + IndexType + " idx = indices[dataInd];")
    code.append(
        'CAFFE_ENFORCE(idx >=0 && idx < data_size, "Index ", dataInd, " is out of bounds: ", idx, ", range 0 to ", data_size);')

    # Per-row weight; quantized rows additionally fold the row's scale into
    # the weight and carry a separate bias term.
    code.append(OutType + " wgt = 1.f;")
    if is_quantized:
        code.append(OutType + " bio;")
    code.append("if (weights) {")
    code.append("wgt = weights[dataInd];")
    code.append("}")
    if is_quantized:
        code.append("bio = wgt * scale_bias[2 * idx + 1];")
        code.append("wgt = wgt * scale_bias[2 * idx];")
        code.append("__m256 vbio = _mm256_set1_ps(bio);")
    code.append("__m256 vwgt = _mm256_set1_ps(wgt);")

    # Row to accumulate now, and the row prefdist_T0 iterations ahead that
    # gets prefetched.
    code.append("const  " + InType + " *ip = &input[idx * block_size];")
    code.append("const  " + IndexType +
                " next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd;")
    code.append("const  " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);")
    code.append("const  " + InType +
                " *ip_next_T0 = &input[idx_pref_T0 * block_size];")

    for j in offsets:
        # Issue a prefetch only where a new cache line of the row begins.
        starts_cacheline = (element_size * j) % cachelinesize == 0
        code.extend(compute(j, starts_cacheline))
    code.append("}")

    code.append("if (normalize_by_lengths == false) {")
    for j in offsets:
        code.append("_mm256_storeu_ps(&op[%d], vop%d);" % (j, j))
    code.append("} else if (lengths[rangeIndex]) {")
    # inverse of the segment length, broadcast to all lanes
    code.append(
        "__m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);")
    for j in offsets:
        code.append(
            "_mm256_storeu_ps(&op[%d], _mm256_mul_ps(vop%d, vlen_inv));" % (j, j))
    code.append("}")

    code.append("}")
    return code


def generic(IndexType, InType, OutType, use_weights, isa):
    """Emit the C++ lookup loop that works for any ``block_size``.

    Unlike the unrolled variant, the generated code accumulates directly in
    ``out``: each output row is zeroed, every looked-up input row is added
    eight floats at a time with a scalar loop for the remainder, and the row
    is finally scaled by ``1 / lengths[rangeIndex]`` when
    ``normalize_by_lengths`` is set and the length is non-zero.

    Args:
        IndexType: C++ type used for ``indices`` and the loop counters.
        InType: C++ element type of ``input``; one of ``"float"``,
            ``"float16"`` or ``"uint8_t"``.
        OutType: C++ element type of ``out`` and of the scalar weight.
        use_weights: Accepted for interface compatibility; not consulted.
        isa: Accepted for interface compatibility; not consulted.

    Returns:
        A list of C++ source lines (no trailing newlines).

    Raises:
        ValueError: If ``InType`` is not one of the supported types.
    """
    # Reject unsupported types with a real exception: ``assert False`` is
    # stripped under ``python -O`` and would let an incomplete kernel through.
    if InType not in ("float", "float16", "uint8_t"):
        raise ValueError("unsupported InType: %s" % (InType,))
    is_quantized = InType == "uint8_t"

    def compute():
        # The literals below continue across source lines with a trailing
        # backslash, so the indentation of the continuation lines ends up
        # inside the emitted string.  It is kept unchanged on purpose so the
        # generated file stays byte-stable.
        code = []
        if InType == "float":
            code.append("_mm256_storeu_ps(&op[j], \
                                 _mm256_fmadd_ps(vwgt,_mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])) \
                                   );")
        elif InType == "float16":
            code.append("_mm256_storeu_ps(&op[j], \
                   _mm256_fmadd_ps(vwgt, \
                     _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(&ip[j]))), _mm256_loadu_ps(&op[j])) \
                                   );")
        else:
            # Only "uint8_t" can reach this branch (validated above).
            code.append("_mm256_storeu_ps(&op[j], \
                   _mm256_fmadd_ps(vwgt, \
                     _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(&ip[j])))), \
                     _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio) ) \
                                   );")

        code.append("_mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);")

        return code

    code = []
    code.append(IndexType + " dataInd = 0;")
    code.append("for (" + IndexType +
                " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {")
    code.append(OutType + " *op = &out[rangeIndex * block_size];")

    # zero the output row: vector part first, then the scalar remainder
    code.append("TIndex j = 0;")
    code.append("for(; j + 8 <= block_size; j += 8) {")
    code.append("_mm256_storeu_ps(op + j, _mm256_setzero_ps());")
    code.append("}")
    code.append("for(; j < block_size; j++) {")
    code.append("op[j] = 0.0f;")
    code.append("}")

    # inner loop over the indices that belong to this output row
    code.append("for (" + IndexType +
                " start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) {")
    code.append("const  " + IndexType + " idx = indices[dataInd];")
    code.append(
        'CAFFE_ENFORCE(idx >=0 && idx < data_size, "Index ", dataInd, " is out of bounds: ", idx, ", range 0 to ", data_size);')

    # Per-row weight; quantized rows additionally fold the row's scale into
    # the weight and carry a separate bias term.
    code.append(OutType + " wgt = 1.f;")
    if is_quantized:
        code.append(OutType + " bio;")
    code.append("if (weights) {")
    code.append("wgt = weights[dataInd];")
    code.append("}")
    if is_quantized:
        code.append("assert (scale_bias);")
        code.append("bio = wgt * scale_bias[2 * idx + 1];")
        code.append("wgt = wgt * scale_bias[2 * idx];")
        code.append("__m256 vbio = _mm256_set1_ps(bio);")
    code.append("__m256 vwgt = _mm256_set1_ps(wgt);")

    # Row to accumulate now, and the row prefdist_T0 iterations ahead that
    # gets prefetched.
    code.append("const  " + InType + " *ip = &input[idx * block_size];")
    code.append("const  " + IndexType +
                " next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd;")
    code.append("const  " + IndexType + " idx_pref_T0 = indices[next_T0];")
    code.append(
        "CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);")
    code.append("const  " + InType +
                " *ip_next_T0 = &input[idx_pref_T0 * block_size];")

    # compute and store main loop
    code.append("j = 0;")
    code.append("for(; j + 8 <= block_size; j += 8) {")
    code.extend(compute())
    code.append("}")

    # leftover elements, one at a time
    if InType == "float16":
        # Scratch buffer for converting a single half value with the vector
        # conversion intrinsic; aligned via the CAFFE2_ALIGNED macro rather
        # than a compiler-specific attribute.
        code.append("float16 vtmp1[8] CAFFE2_ALIGNED(64);")
    code.append("for(; j < block_size; j++) {")
    if InType == "float":
        code.append("op[j] += wgt * ip[j];")
    elif InType == "float16":
        code.append("vtmp1[0] = ip[j];")
        code.append("__m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));")
        code.append("op[j] += wgt * ((float*)(&vtmp2))[0];")
    else:
        code.append("op[j] += wgt * ((float)ip[j]) + bio;")
    code.append("}")

    code.append("}")

    # optional normalisation by the segment length
    code.append("if (normalize_by_lengths && lengths[rangeIndex]) {")
    code.append("float len_inv = 1.0f / lengths[rangeIndex];")
    code.append("__m256 vlen_inv = _mm256_set1_ps(len_inv);")
    code.append("j = 0;")
    code.append("for(; j + 8 <= block_size; j += 8) {")
    code.append(
        "_mm256_storeu_ps(&op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));")
    code.append("}")
    code.append("for(; j < block_size; j++) {")
    code.append("op[j] = len_inv * op[j];")
    code.append("}")

    code.append("}")

    code.append("}")
    return code



# start main code
#
# Generates the C++ source with one EmbeddingLookup_*__avx2_fma function per
# entry of ``options`` and writes it to the file given with ``-f`` (default:
# embedding_lookup_avx2.cc).
parser = argparse.ArgumentParser()
parser.add_argument('-f', nargs=1, help="file name")
opts = parser.parse_args()
filename = "embedding_lookup_avx2.cc"
if opts.f:
    filename = (opts.f)[0]

# Each entry is [IndexType, InType, OutType] of one generated function.
options = [["int32_t", "float",   "float"],
           ["int64_t", "float",   "float"],
           ["int32_t", "float16", "float"],
           ["int64_t", "float16", "float"],
           ["int32_t", "uint8_t",  "float"],
           ["int64_t", "uint8_t",  "float"],
          ]

# (block_size, unroll factor) pairs that get a fully unrolled kernel; every
# other block size falls back to the generic loop.
unrolled_block_sizes = [(128, 16), (64, 8), (32, 4), (16, 2)]

code = []
# license header of the generated file
code.append(
"""/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
""")
code.append("//// --------------------------")
code.append("//// ATTENTION:                ")
code.append("//// THIS CODE IS AUTOGENERATED")
code.append("//// BY %s                     " % (sys.argv[0]))
code.append("//// DO NOT MODIFY!!!          ")
code.append("//// --------------------------\n\n")

# includes
code.append("#include \"caffe2/core/types.h\"")
code.append("#include \"caffe2/core/common.h\"")
code.append("#include <immintrin.h>")
code.append("\n")

code.append("namespace caffe2 {\n")
for IndexType, InType, OutType in options:
    # function signature
    fn = "void EmbeddingLookup_" + IndexType + \
        "_" + InType + "_" + OutType + "__avx2_fma"
    code.append(fn + "(")
    code.append("const TIndex block_size,")
    code.append("const TIndex output_size,")
    code.append("const TIndex index_size,")
    code.append("const TIndex data_size,")
    code.append("const " + InType + "* input,")
    code.append("const " + IndexType + "* indices,")
    code.append("const int* lengths,")
    code.append("const float* weights,")
    code.append("const float* scale_bias,")
    code.append("bool normalize_by_lengths,")
    code.append(OutType + "* out)")

    code.append("{")
    code.append("const " + IndexType + " prefdist_T0 = 16;")
    # scale_bias is only meaningful (and required) for quantized input
    if InType != "uint8_t":
        code.append("CAFFE_ENFORCE(scale_bias == nullptr, \"scale_bias must be nullptr\");")
    else:
        code.append("CAFFE_ENFORCE(scale_bias != nullptr, \"scale_bias must not be nullptr\");")

    # dispatch on block_size: unrolled kernels first, generic loop last
    for position, (block_size, unroll_factor) in enumerate(unrolled_block_sizes):
        keyword = "if" if position == 0 else "} else if"
        code.append("%s (block_size == %d) {" % (keyword, block_size))
        code.extend(
            unroll(unroll_factor, IndexType, InType, OutType, True, "AVX2"))
    code.append("} else {")
    code.append("// generic code")
    code.extend(generic(IndexType, InType, OutType, True, "AVX2"))
    code.append("}")

    code.append("}")

    code.append("\n")
code.append("} // namespace caffe2")

# The file is opened only after generation succeeded, so a failure above no
# longer leaves a truncated output file behind, and the ``with`` block closes
# the handle even if a write fails.
with open(filename, 'w') as fout:
    for c in code:
        fout.write(c + "\n")

print("Created " + filename)
